{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "# 3.5 使用因果注意力机制隐藏后续词\n",
        "\n",
        "在本节中，我们将修改标准的自注意力机制，创建一个因果注意力（Causal Attention）机制，这对于后续章节中大语言模型的开发至关重要。\n",
        "\n",
        "因果注意力，也称为遮蔽注意力（masked attention），是自注意力的一种特殊形式。它限制模型在处理任何给定 Token 时，只考虑序列中之前和当前的输入。这与标准的自注意力机制形成对比，后者允许一次访问整个输入序列。\n",
        "\n",
        "因此，在计算注意力得分时，因果注意力机制确保模型只考虑序列中当前 Token 或之前出现的 Token。\n",
        "\n",
        "在类似 GPT 的大语言模型中，为了实现这一点，我们对每个处理的 Token 遮蔽掉输入文本中当前 Token 之后的后续 Token，如图 3.19 所示。\n",
        "\n",
        "**图 3.19 在因果注意力中，我们遮蔽掉对角线以上的注意力权重，以便在计算上下文向量时，大语言模型无法访问后续的Token。例如，在第二行中，对于单词“journey”，我们只保留“Your”（之前的单词）和“journey”（当前位置）的注意力权重。**\n",
        "\n",
        "![3.19](../img/fig-3-19.jpg)\n",
        "\n",
        "正如图 3.19 所示，我们遮蔽了对角线以上的注意力权重，并标准化未遮蔽的注意力权重，使得每一行的注意力权重之和为 1。在下一节中，我们将在代码中实现这种遮蔽和标准化的过程。\n",
        "\n",
        "## 3.5.1 应用因果注意力遮蔽\n",
        "\n",
        "在本节中，我们将在代码中实现因果注意力遮蔽（causal attention mask）。我们从图 3.20 中总结的程序开始。\n",
        "\n",
        "**图 3.20 在因果注意力机制中获取遮蔽的注意力权重矩阵的一种方式是对注意力得分应用 softmax 函数，将对角线以上的元素归零并标准化结果矩阵。**\n",
        "\n",
        "![3.20](../img/fig-3-20.jpg)\n",
        "\n",
        "为了实现图3.20所示的因果注意力遮蔽步骤并获得遮蔽的注意力权重，让我们使用前一节中的注意力得分和权重来编码因果注意力机制。\n",
        "\n",
        "第一步，如图 3.20 所示，我们使用 softmax 函数计算注意力权重，如前几节所做的那样：\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 5,
      "metadata": {
        "trusted": true
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "tensor([[0.1921, 0.1646, 0.1652, 0.1550, 0.1721, 0.1510],\n",
            "        [0.2041, 0.1659, 0.1662, 0.1496, 0.1665, 0.1477],\n",
            "        [0.2036, 0.1659, 0.1662, 0.1498, 0.1664, 0.1480],\n",
            "        [0.1869, 0.1667, 0.1668, 0.1571, 0.1661, 0.1564],\n",
            "        [0.1830, 0.1669, 0.1670, 0.1588, 0.1658, 0.1585],\n",
            "        [0.1935, 0.1663, 0.1666, 0.1542, 0.1666, 0.1529]],\n",
            "       grad_fn=<SoftmaxBackward>)\n"
          ]
        }
      ],
      "source": [
        "queries = sa_v2.W_query(inputs)  #A\n",
        "keys = sa_v2.W_key(inputs) \n",
        "attn_scores = queries @ keys.T\n",
        "attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=1)\n",
        "print(attn_weights)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "得到以下的注意力权重：\n",
        "```python\n",
        "tensor([[0.1921, 0.1646, 0.1652, 0.1550, 0.1721, 0.1510],\n",
        "        [0.2041, 0.1659, 0.1662, 0.1496, 0.1665, 0.1477],\n",
        "        [0.2036, 0.1659, 0.1662, 0.1498, 0.1664, 0.1480],\n",
        "        [0.1869, 0.1667, 0.1668, 0.1571, 0.1661, 0.1564],\n",
        "        [0.1830, 0.1669, 0.1670, 0.1588, 0.1658, 0.1585],\n",
        "        [0.1935, 0.1663, 0.1666, 0.1542, 0.1666, 0.1529]],\n",
        "       grad_fn=<SoftmaxBackward0>)\n",
        "```"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "我们可以使用 PyTorch 的 tril 函数实现图 3.20 中的第二步，创建一个遮蔽，使得对角线以上的值为零："
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 6,
      "metadata": {
        "trusted": true
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "tensor([[1., 0., 0., 0., 0., 0.],\n",
            "        [1., 1., 0., 0., 0., 0.],\n",
            "        [1., 1., 1., 0., 0., 0.],\n",
            "        [1., 1., 1., 1., 0., 0.],\n",
            "        [1., 1., 1., 1., 1., 0.],\n",
            "        [1., 1., 1., 1., 1., 1.]])\n"
          ]
        }
      ],
      "source": [
        "context_length = attn_scores.shape[0]\n",
        "mask_simple = torch.tril(torch.ones(context_length, context_length))\n",
        "print(mask_simple)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "得到的遮蔽如下：\n",
        "```python\n",
        " tensor([[1., 0., 0., 0., 0., 0.],\n",
        "        [1., 1., 0., 0., 0., 0.],\n",
        "        [1., 1., 1., 0., 0., 0.],\n",
        "        [1., 1., 1., 1., 0., 0.],\n",
        "        [1., 1., 1., 1., 1., 0.],\n",
        "        [1., 1., 1., 1., 1., 1.]])\n",
        "```"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "现在，我们可以将这个遮蔽与注意力权重相乘，将对角线以上的值归零："
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 7,
      "metadata": {
        "trusted": true
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "tensor([[0.1921, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],\n",
            "        [0.2041, 0.1659, 0.0000, 0.0000, 0.0000, 0.0000],\n",
            "        [0.2036, 0.1659, 0.1662, 0.0000, 0.0000, 0.0000],\n",
            "        [0.1869, 0.1667, 0.1668, 0.1571, 0.0000, 0.0000],\n",
            "        [0.1830, 0.1669, 0.1670, 0.1588, 0.1658, 0.0000],\n",
            "        [0.1935, 0.1663, 0.1666, 0.1542, 0.1666, 0.1529]],\n",
            "       grad_fn=<MulBackward0>)\n"
          ]
        }
      ],
      "source": [
        "masked_simple = attn_weights*mask_simple\n",
        "print(masked_simple)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "我们可以看到，对角线以上的元素已成功归零：\n",
        "```python\n",
        "tensor([[0.1921, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],\n",
        "        [0.2041, 0.1659, 0.0000, 0.0000, 0.0000, 0.0000],\n",
        "        [0.2036, 0.1659, 0.1662, 0.0000, 0.0000, 0.0000],\n",
        "        [0.1869, 0.1667, 0.1668, 0.1571, 0.0000, 0.0000],\n",
        "        [0.1830, 0.1669, 0.1670, 0.1588, 0.1658, 0.0000],\n",
        "        [0.1935, 0.1663, 0.1666, 0.1542, 0.1666, 0.1529]],\n",
        "       grad_fn=<MulBackward0>)\n",
        "```"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "图 3.20 中的第三步是重新标准化注意力权重，使每一行的和再次为 1。我们可以通过将每一行中的每个元素除以该行的总和来实现这一点："
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 8,
      "metadata": {
        "trusted": true
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],\n",
            "        [0.5517, 0.4483, 0.0000, 0.0000, 0.0000, 0.0000],\n",
            "        [0.3800, 0.3097, 0.3103, 0.0000, 0.0000, 0.0000],\n",
            "        [0.2758, 0.2460, 0.2462, 0.2319, 0.0000, 0.0000],\n",
            "        [0.2175, 0.1983, 0.1984, 0.1888, 0.1971, 0.0000],\n",
            "        [0.1935, 0.1663, 0.1666, 0.1542, 0.1666, 0.1529]],\n",
            "       grad_fn=<DivBackward0>)\n"
          ]
        }
      ],
      "source": [
        "row_sums = masked_simple.sum(dim=1, keepdim=True)\n",
        "masked_simple_norm = masked_simple / row_sums\n",
        "print(masked_simple_norm)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "得到了一个注意力权重矩阵，其中对角线以上的注意力权重被归零，并且每行的和为 1：\n",
        "```python\n",
        "tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],\n",
        "        [0.5517, 0.4483, 0.0000, 0.0000, 0.0000, 0.0000],\n",
        "        [0.3800, 0.3097, 0.3103, 0.0000, 0.0000, 0.0000],\n",
        "        [0.2758, 0.2460, 0.2462, 0.2319, 0.0000, 0.0000],\n",
        "        [0.2175, 0.1983, 0.1984, 0.1888, 0.1971, 0.0000],\n",
        "        [0.1935, 0.1663, 0.1666, 0.1542, 0.1666, 0.1529]],\n",
        "       grad_fn=<DivBackward0>)\n",
        "```"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "### 信息泄露\n",
        "\n",
        "当我们应用遮蔽然后重新标准化注意力权重时，可能会出现后续的 Token（我们打算遮蔽的）的信息仍影响当前的 Token 的情况，这是因为它们的值是 softmax 函数计算的一部分。然而，关键点在于，当我们在遮蔽后重新标准化注意力权重时，我们实际上是在一个更小的子集上重新计算 softmax 函数（因为遮蔽位置不会对 softmax 值有任何贡献）。\n",
        "\n",
        "softmax的数学优雅之处在于，尽管在最初的计算中分母包含了所有位置，但在遮蔽和重新归一化之后，被遮蔽的位置的影响被消除了————它们不会以任何有意义的方式对 softmax 得分产生影响。\n",
        "\n",
        "简而言之，经过遮蔽和重新标准化后，注意力权重的分布就好像一开始只在未遮蔽位置上计算一样。这确保了后续（或其他遮蔽）Token 的信息不会像我们想象的那样泄露。\n",
        "\n",
        "虽然此时技术上可以完成因果注意力的实现，但我们可以利用 softmax 函数的一个数学特性，并以更少的步骤更有效地实现遮蔽注意力权重的计算，如图 3.21 所示。\n",
        "\n",
        "**图3.21 一种在因果注意力中获得掩蔽注意力权重矩阵的更高效方法，是在应用 softmax 函数之前，用负无穷大值遮蔽注意力得分。**\n",
        "\n",
        "![3.21](../img/fig-3-21.jpg)\n",
        "\n",
        "softmax 函数将其输入转换为概率分布。当一行中存在负无穷大（-∞）值时，softmax 函数将其概率视为零。（数学上，这是因为 e^-∞ 趋近于 0。）\n",
        "\n",
        "我们可以通过创建一个对角线以上是 1 的遮蔽，然后将这些 1 替换为负无穷大（-inf）值来实现这种更高效的遮蔽技巧：\n",
        "\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 19,
      "metadata": {
        "trusted": true
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "tensor([[0.2899,   -inf,   -inf,   -inf,   -inf,   -inf],\n",
            "        [0.4656, 0.1723,   -inf,   -inf,   -inf,   -inf],\n",
            "        [0.4594, 0.1703, 0.1731,   -inf,   -inf,   -inf],\n",
            "        [0.2642, 0.1024, 0.1036, 0.0186,   -inf,   -inf],\n",
            "        [0.2183, 0.0874, 0.0882, 0.0177, 0.0786,   -inf],\n",
            "        [0.3408, 0.1270, 0.1290, 0.0198, 0.1290, 0.0078]],\n",
            "       grad_fn=<MaskedFillBackward0>)\n"
          ]
        }
      ],
      "source": [
        "mask = torch.triu(torch.ones(context_length, context_length), diagonal=1)\n",
        "masked = attn_scores.masked_fill(mask.bool(), -torch.inf)\n",
        "print(masked)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "得到以下遮蔽：\n",
        "```python\n",
        "tensor([[0.2899,   -inf,   -inf,   -inf,   -inf,   -inf],\n",
        "        [0.4656, 0.1723,   -inf,   -inf,   -inf,   -inf],         \n",
        "        [0.4594, 0.1703, 0.1731,   -inf,   -inf,   -inf],\n",
        "        [0.2642, 0.1024, 0.1036, 0.0186,   -inf,   -inf],\n",
        "        [0.2183, 0.0874, 0.0882, 0.0177, 0.0786,   -inf],\n",
        "        [0.3408, 0.1270, 0.1290, 0.0198, 0.1290, 0.0078]],\n",
        "       grad_fn=<MaskedFillBackward0>)\n",
        "```"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "现在，我们只需对这些遮蔽结果应用 softmax 函数即可完成："
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 20,
      "metadata": {
        "trusted": true
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],\n",
            "        [0.5517, 0.4483, 0.0000, 0.0000, 0.0000, 0.0000],\n",
            "        [0.3800, 0.3097, 0.3103, 0.0000, 0.0000, 0.0000],\n",
            "        [0.2758, 0.2460, 0.2462, 0.2319, 0.0000, 0.0000],\n",
            "        [0.2175, 0.1983, 0.1984, 0.1888, 0.1971, 0.0000],\n",
            "        [0.1935, 0.1663, 0.1666, 0.1542, 0.1666, 0.1529]],\n",
            "       grad_fn=<SoftmaxBackward>)\n"
          ]
        }
      ],
      "source": [
        "attn_weights = torch.softmax(masked / keys.shape[-1]**0.5, dim=1)\n",
        "print(attn_weights)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "可以看出，根据输出，每行的值之和为 1，无需进一步标准化：\n",
        "```python\n",
        "tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],\n",
        "        [0.5517, 0.4483, 0.0000, 0.0000, 0.0000, 0.0000],\n",
        "        [0.3800, 0.3097, 0.3103, 0.0000, 0.0000, 0.0000],\n",
        "        [0.2758, 0.2460, 0.2462, 0.2319, 0.0000, 0.0000],\n",
        "        [0.2175, 0.1983, 0.1984, 0.1888, 0.1971, 0.0000],\n",
        "        [0.1935, 0.1663, 0.1666, 0.1542, 0.1666, 0.1529]],\n",
        "       grad_fn=<SoftmaxBackward0>)\n",
        "```"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "我们现在可以使用修改后的注意力权重通过 context_vec = attn_weights @ values 来计算上下文向量，正如 3.4 节中所做的那样。\n",
        "\n",
        "在下一节中，我们首先将介绍对因果注意力机制的另一个小调整，这对于在训练大语言模型时减少过拟合非常有用。\n",
        "\n",
        "## 3.5.2 通过 Dropout 遮蔽额外的注意力权重\n",
        "\n",
        "在深度学习中，Dropout 是一种技术，即在训练过程中随机忽略选定的隐藏层单元，有效地将它们“丢弃”。这种方法有助于防止过拟合，确保模型不会过度依赖任何特定的隐藏层单元组。需要强调的是，Dropout 仅在训练期间使用，在之后不可以使用。\n",
        "\n",
        "在包括 GPT 在内的 Transformer 架构中，注意力机制中的 Dropout 通常应用于两个特定区域：计算注意力得分之后，或将注意力权重应用于值向量之后。\n",
        "\n",
        "这里，我们将在计算注意力权重后应用 Dropout 遮蔽，如图 3.22 所示，这是实践中更常见的变体。\n",
        "\n",
        "**图 3.22 利用因果注意力遮蔽（左上角），我们应用额外的 Dropout 遮蔽（右上角）来归零额外的注意力权重，以减少训练期间的过拟合。**\n",
        "\n",
        "![3.22](../img/fig-3-22.jpg)\n",
        "\n",
        "在以下代码示例中，我们使用了 50% 的 Dropout 率，这意味着遮蔽掉一半的注意力权重。（在后面的章节中训练 GPT 模型时，我们将使用较低的 Dropout 率，例如 0.1 或 0.2。）\n",
        "\n",
        "在下面的代码中，我们首先将 PyTorch 的 Dropout 实现应用于一个由 1 组成的 6x6 张量，用于示例说明：\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 21,
      "metadata": {
        "trusted": true
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "tensor([[2., 2., 2., 2., 2., 2.],\n",
            "        [0., 2., 0., 0., 0., 0.],\n",
            "        [0., 0., 2., 0., 2., 0.],\n",
            "        [2., 2., 0., 0., 0., 2.],\n",
            "        [2., 0., 0., 0., 0., 2.],\n",
            "        [0., 2., 0., 0., 0., 0.]])\n"
          ]
        }
      ],
      "source": [
        "torch.manual_seed(123)\n",
        "dropout = torch.nn.Dropout(0.5) #A\n",
        "example = torch.ones(6, 6) #B\n",
        "print(dropout(example))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "可以看到，大约一半的值被归零了：\n",
        "```python\n",
        "tensor([[2., 2., 0., 2., 2., 0.],\n",
        "        [0., 0., 0., 2., 0., 2.],\n",
        "        [2., 2., 2., 2., 0., 2.],\n",
        "        [0., 2., 2., 0., 0., 2.],\n",
        "        [0., 2., 0., 2., 0., 2.],\n",
        "        [0., 2., 2., 2., 2., 0.]])\n",
        "```"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "当对注意力权重矩阵应用 50% 的 Dropout 率时，矩阵中一半的元素被随机设为零。为了补偿活跃元素的减少，矩阵中剩余元素的值被放大了 1/0.5 = 2 倍。这种放大对于保持注意力权重的整体平衡至关重要，它能确保在训练和推理阶段，注意力机制的平均影响保持一致。\n",
        "\n",
        "现在，让我们将 Dropout 应用于注意力权重矩阵本身："
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 22,
      "metadata": {
        "trusted": true
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "tensor([[2.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],\n",
            "        [0.0000, 0.8966, 0.0000, 0.0000, 0.0000, 0.0000],\n",
            "        [0.0000, 0.0000, 0.6206, 0.0000, 0.0000, 0.0000],\n",
            "        [0.5517, 0.4921, 0.0000, 0.0000, 0.0000, 0.0000],\n",
            "        [0.4350, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],\n",
            "        [0.0000, 0.3327, 0.0000, 0.0000, 0.0000, 0.0000]],\n",
            "       grad_fn=<MulBackward0>)\n"
          ]
        }
      ],
      "source": [
        "torch.manual_seed(123)\n",
        "print(dropout(attn_weights))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "处理后的注意力权重矩阵中有更多的元素被归零，剩余的元素被重新缩放：\n",
        "```python\n",
        "tensor([[2.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],\n",
        "        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],\n",
        "        [0.7599, 0.6194, 0.6206, 0.0000, 0.0000, 0.0000],\n",
        "        [0.0000, 0.4921, 0.4925, 0.0000, 0.0000, 0.0000],\n",
        "        [0.0000, 0.3966, 0.0000, 0.3775, 0.0000, 0.0000],\n",
        "        [0.0000, 0.3327, 0.3331, 0.3084, 0.3331, 0.0000]],\n",
        "       grad_fn=<MulBackward0>\n",
        "```"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "请注意，由于操作系统的不同，Dropout 的输出结果可能会有所不同；关于这种不一致性的更多信息，您可以在 PyTorch 问题跟踪器上(https://github.com/pytorch/pytorch/issues/121595) 阅读。\n",
        "\n",
        "在了解了因果注意力和 Dropout 遮蔽后，我们将在下一节中开发一个简洁的 Python 类。这个类旨在便于高效地应用这两种技术。\n",
        "\n",
        "## 3.5.3 实现一个紧凑的 causal attention 类\n",
        "\n",
        "在这一节中，我们将因果注意力和 Dropout 技术集成到我们在第 3.4 节开发的 SelfAttention Python 类中。这个类随后将作为在即将到来的章节中开发多头注意力（ multi-head\n",
        " attention ）的模板，这是本章中我们将实现的最后的attention类。\n",
        "\n",
        "但在开始之前，还有一件事要确保，那就是代码能够处理由多个输入组成的批次，以便 CausalAttention 类支持我们在第二章中实现的数据加载器生成的批次输出。\n",
        "\n",
        "为简化模拟这种批次输入，我们复制输入的文本示例："
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 23,
      "metadata": {
        "trusted": true
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "torch.Size([2, 6, 3])\n"
          ]
        }
      ],
      "source": [
        "batch = torch.stack((inputs, inputs), dim=0)\n",
        "print(batch.shape) #A "
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "这将生成一个三维张量，包含 2 个输入文本，每个文本有 6 个 Token，每个 Token 是一个三维嵌入向量：\n",
        "```python\n",
        "torch.Size([2, 6, 3])\n",
        "```"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "下面的 CausalAttention 类与我们之前实现的 SelfAttention 类相似，我们只是额外添加了下面代码中突出显示的 Dropout 和因果遮蔽部分：\n",
        "\n",
        "### 清单 3.3 一个紧凑的因果注意力类"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 29,
      "metadata": {
        "trusted": true
      },
      "outputs": [],
      "source": [
        "class CausalAttention(nn.Module):\n",
        "\n",
        "    def __init__(self, d_in, d_out, context_length, dropout, qkv_bias=False):\n",
        "        super().__init__()\n",
        "        self.d_out = d_out\n",
        "        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)\n",
        "        self.W_key   = nn.Linear(d_in, d_out, bias=qkv_bias)\n",
        "        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)\n",
        "        self.dropout = nn.Dropout(dropout) # New\n",
        "        self.register_buffer('mask', torch.triu(torch.ones(context_length, context_length), diagonal=1)) # New\n",
        "\n",
        "    def forward(self, x):\n",
        "        b, num_tokens, d_in = x.shape # New batch dimension b\n",
        "        keys = self.W_key(x)\n",
        "        queries = self.W_query(x)\n",
        "        values = self.W_value(x)\n",
        "\n",
        "        attn_scores = queries @ keys.transpose(1, 2) # Changed transpose\n",
        "        attn_scores.masked_fill_(  # New, _ ops are in-place\n",
        "            self.mask.bool()[:num_tokens, :num_tokens], -torch.inf) \n",
        "        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)\n",
        "        attn_weights = self.dropout(attn_weights) # New\n",
        "\n",
        "        context_vec = attn_weights @ values\n",
        "        return context_vec"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "虽然所有新增代码行和之前小节中的代码相似，但我们现在在 ‘__init__’ 方法中添加了一个 self.register_buffer() 调用。在 PyTorch 中使用 register_buffer 并不是所有情况下都必须的，但在这里有几个优点。例如，当我们在大型语言模型中使用 CausalAttention 类时，缓冲区会随着模型自动移动到适当的设备（CPU或GPU）上，这在后续章节中训练大语言模型时会很有用。这意味着我们不需要手动确保这些张量与模型参数在同一设备上，从而避免设备不匹配错误。\n",
        "\n",
        "我们可以像之前使用 SelfAttention 类 那样使用 CausalAttention 类："
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 30,
      "metadata": {
        "trusted": true
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "context_vecs.shape: torch.Size([2, 6, 2])\n"
          ]
        }
      ],
      "source": [
        "torch.manual_seed(123)\n",
        "context_length = batch.shape[1]\n",
        "ca = CausalAttention(d_in, d_out, context_length, 0.0)\n",
        "context_vecs = ca(batch)\n",
        "print(\"context_vecs.shape:\", context_vecs.shape)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "得到的上下文向量是一个三维张量，其中每个 Token 现在由一个二维嵌入表示：\n",
        "```python\n",
        "context_vecs.shape: torch.Size([2, 6, 2])\n",
        "```"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "图 3.23 提供了一个能够总结我们到目前为止所完成内容的心智模型。\n",
        "\n",
        "图 3.23 一个概括了本章我们编写的四种不同注意力模块的心智模型。我们从一个简化的注意力机制开始，添加了可训练的权重，然后增加了因果注意力遮蔽。在本章的剩余部分，我们将扩展因果注意力机制，并编写多头注意力机制，这是我们在下一章的大语言模型实现中将使用的最终模块。\n",
        "\n",
        "![3.23](../img/fig-3-23.jpg)\n",
        "\n",
        "如图 3.23 所示，本节中，我们专注于神经网络中因果注意力的概念梳理和实现。在下一节中，我们将进一步扩展这一概念，实现一个多头注意力模块，该模块并行实现多个这样的因果注意力机制。\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "trusted": true
      },
      "outputs": [],
      "source": []
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "trusted": true
      },
      "outputs": [],
      "source": []
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "trusted": true
      },
      "outputs": [],
      "source": []
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "trusted": true
      },
      "outputs": [],
      "source": []
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "trusted": true
      },
      "outputs": [],
      "source": []
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "minitorch",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.8.12"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 4
}
